#
# Copyright 2022- IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache2.0
#

# ==================================================================================================
# IMPORTS
# ==================================================================================================
import csv
import datetime
import time
from copy import copy
from operator import itemgetter
import os
import shutil

import torch
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
from dotmap import DotMap

import progressbar
import tqdm
from .util import Logger
from .model import *
# from lib.util import csv2dict, loadmat
from .torch_blocks import *
from .confusion_support import plot_confusion_support, avg_sim_confusion
import os.path
import pdb
from torchvision import transforms
from .PCA import *
from .RandMix import RandMix
from .sinkhorn_distance import SinkhornDistance


from .dataloader.FSCIL.data_utils import *





def mixup_data(x, y, device_idx, alpha=1.0, use_cuda=True):
    """Mix every sample of a batch with a randomly permuted partner (mixup).

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: targets aligned with ``x`` along the first dimension.
        device_idx: CUDA device the permutation indices are moved to when
            ``use_cuda`` is true.
        alpha: concentration of the Beta(alpha, alpha) draw for lambda;
            ``alpha <= 0`` disables mixing (lambda is fixed to 1).
        use_cuda: whether to move the permutation indices to ``device_idx``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` being ``y`` itself and ``y_b`` being ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    perm = torch.randperm(x.size()[0])
    if use_cuda:
        perm = perm.cuda(device_idx)

    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the lambda-weighted mixup loss ``lam * L(pred, y_a) + (1 - lam) * L(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def rand_bbox(size, lam):
    """Sample a random CutMix box whose side lengths are scaled by ``sqrt(1 - lam)``.

    Args:
        size: tensor shape; ``size[2]`` and ``size[3]`` are the two spatial
            extents (named W and H here).
        lam: mixing ratio in [0, 1]; ``lam == 1`` yields an empty box.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: box corners along dims 2 and 3, clipped to
        the image, so the real area can be smaller than requested near borders.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre drawn uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(input_x, input_y, device_idx, beta=1.0):
    """Apply CutMix: paste a random box from a shuffled partner into every sample.

    Args:
        input_x: image batch indexed as ``input_x[n, c, d2, d3]``. It is NOT
            modified; the mixed batch is a new tensor.
        input_y: targets aligned with ``input_x`` along the first dimension.
        device_idx: kept for backward compatibility; the permutation is now
            placed on ``input_x.device`` so CPU and GPU inputs both work.
        beta: concentration of the Beta(beta, beta) draw for lambda;
            ``beta <= 0`` disables mixing (lambda = 1), mirroring ``mixup_data``.

    Returns:
        ``(mixed_x, target_a, target_b, lam)`` where ``lam`` is the fraction of
        each image that still comes from the original sample.
    """
    lam = np.random.beta(beta, beta) if beta > 0 else 1.0
    # Draw the permutation with the CPU generator (same RNG stream as before)
    # and move it next to the data.
    rand_index = torch.randperm(input_x.size(0)).to(input_x.device)

    target_a = input_y
    target_b = input_y[rand_index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(input_x.size(), lam)
    # Work on a copy: callers such as pretrain_baseFSCIL concatenate the result
    # with the original batch and expect that batch to stay clean.
    mixed_x = input_x.clone()
    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = input_x[rand_index, :, bbx1:bbx2, bby1:bby2]
    # Adjust lambda to exactly match the pasted pixel ratio (the box may be clipped).
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input_x.size(-1) * input_x.size(-2)))

    return mixed_x, target_a, target_b, lam




# def pretrain_baseFSCIL(verbose, **parameters):
#     '''
#     Pre-training on base session
#     '''
#     args = DotMap(parameters)
#     # args.gpu = 4
#     writer = SummaryWriter(args.log_dir)
#
#     # Initialize the dataset generator and the model
#     args = set_up_datasets(args)
#     trainset, train_loader, val_loader = get_base_dataloader(args)
#
#     model = KeyValueNetwork(args)
#
#     model.mode = 'pretrain'
#     # Store all parameters in a variable
#     parameters_list, parameters_table = process_dictionary(parameters)
#     logs_dir = os.path.join(args.log_dir + '/' + 'train_log.txt')
#
#     # Print all parameters
#     if verbose:
#         print("Parameters:")
#         for key, value in parameters_list:
#             print("\t{}".format(key).ljust(40) + "{}".format(value))
#             with open(logs_dir, 'a', encoding='utf-8') as f1:
#                 f1.write("\t{}".format(key).ljust(40) + "{}".format(value)+'\n')
#
#
#     criterion = nn.CrossEntropyLoss()
#
#     if args.gpu is not None:
#         t.cuda.set_device(args.gpu)
#         model = model.cuda(args.gpu)
#         criterion = criterion.cuda(args.gpu)
#
#     for param in model.embedding.conv1.parameters():
#         param.requires_grad = False
#
#
#
#     optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
#                             lr=args.learning_rate,nesterov=args.SGDnesterov,
#                             weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
#     scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)
#
#
#     best_acc1 = 0
#
#     for epoch in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
#         global_count = 0
#         losses = AverageMeter('Loss')
#         acc = AverageMeter('Acc@1')
#         model.train(True)
#
#
#         for i, batch in enumerate(train_loader):
#             global_count = global_count + 1
#             data, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
#             # data, data_aug1, data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
#             # forward pass
#             optimizer.zero_grad()
#             aug_p = torch.rand(1)
#
#             output = model(data)
#             loss_cls = criterion(output, train_label)
#             proxy = model.classifier
#             features = model.fea_rep
#             loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)
#             loss = loss_cls + args.pcl_weight *  loss_pcl
#
#             # Backpropagation
#             loss.backward()
#             optimizer.step()
#
#             losses.update(loss.item(), data.size(0))
#
#
#         scheduler.step()
#
#         # write to tensorboard
#         writer.add_scalar('training_loss/pretrain_CEL', losses.avg, epoch)
#         writer.add_scalar('accuracy/pretrain_train', acc.avg, epoch)
#
#         val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
#         writer.add_scalar('validation_loss/pretrain_CEL', val_loss, epoch)
#         writer.add_scalar('accuracy/pretrain_val', val_acc_mean, epoch)
#
#         is_best = val_acc_mean > best_acc1
#         best_acc1 = max(val_acc_mean, best_acc1)
#
#         with open(logs_dir, 'a', encoding='utf-8') as f1:
#             f1.write(f'epoch: {epoch} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f}\n')
#
#         print('epoch:', epoch, 'current mean acc', val_acc_mean, 'best acc:', best_acc1)
#         save_checkpoint({
#             'train_iter': epoch + 1,
#             'arch': args.block_architecture,
#             'state_dict': model.state_dict(),
#             'best_acc1': best_acc1,
#             'optimizer': optimizer.state_dict(),
#         }, is_best, savedir=args.log_dir)
#
#     writer.close()



def pretrain_baseFSCIL(verbose, **parameters):
    '''
    Pre-training on base session.

    Builds the base-session loaders and a KeyValueNetwork from ``parameters``,
    freezes ``model.embedding.conv1`` and trains the remaining parameters with
    SGD + StepLR. For every batch one of three branches is chosen at random:
    mixup, cutmix, or a clean/two-augmented-views branch with a Jensen-Shannon
    style consistency term. Each branch adds ``args.pcl_weight * PCLoss`` on
    clean-sample features. After each epoch the model is validated, metrics go
    to TensorBoard and ``<log_dir>/train_log.txt``, and ``save_checkpoint`` is
    called.

    Args:
        verbose: if true, print every parameter and append it to train_log.txt.
        **parameters: run configuration, wrapped in a DotMap as ``args``.
    '''
    args = DotMap(parameters)
    # args.gpu = 4
    writer = SummaryWriter(args.log_dir)

    # Initialize the dataset generator and the model
    args = set_up_datasets(args)
    trainset, train_loader, val_loader = get_base_dataloader(args)

    model = KeyValueNetwork(args)

    model.mode = 'pretrain'
    # Store all parameters in a variable
    parameters_list, parameters_table = process_dictionary(parameters)
    # Despite the name, this is the path of the text log file, not a directory.
    logs_dir = os.path.join(args.log_dir + '/' + 'train_log.txt')

    # Print all parameters
    if verbose:
        print("Parameters:")
        for key, value in parameters_list:
            print("\t{}".format(key).ljust(40) + "{}".format(value))
            with open(logs_dir, 'a', encoding='utf-8') as f1:
                f1.write("\t{}".format(key).ljust(40) + "{}".format(value)+'\n')


    criterion = nn.CrossEntropyLoss()

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # Freeze the first conv layer of the embedding; the optimizer below only
    # receives parameters that still require gradients.
    for param in model.embedding.conv1.parameters():
        param.requires_grad = False



    optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.learning_rate,nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
    # scheduler.step() is called once per epoch, so the LR is multiplied by 0.1
    # every args.lr_step_size epochs.
    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)


    best_acc1 = 0

    # Runs args.max_train_iter - 1 epochs (numbering starts at 1).
    for epoch in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
        global_count = 0  # counted per batch but not read elsewhere in this function
        losses = AverageMeter('Loss')
        # NOTE(review): `acc` is never updated in this loop, so the
        # 'accuracy/pretrain_train' scalar below logs its initial average.
        acc = AverageMeter('Acc@1')
        model.train(True)


        for i, batch in enumerate(train_loader):
            global_count = global_count + 1
            # data, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            # Each batch unpacks into clean images, two augmented views, and labels.
            data, data_aug1, data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            # forward pass
            optimizer.zero_grad()
            # One uniform draw per batch selects the augmentation branch.
            aug_p = torch.rand(1)

            if aug_p < 0.33:
                # Mixup branch: forward mixed + clean images in one pass.
                # First half -> mixed CE, second half -> clean CE.
                inputs_aug, targets_a_aug, targets_b_aug, lam = mixup_data(data, train_label, device_idx=args.gpu)

                all_x = torch.cat([inputs_aug, data])
                all_y = model(all_x)
                loss_cls = lam * criterion(all_y[:data.size()[0]], targets_a_aug) + (1 - lam) * criterion(
                    all_y[:data.size()[0]], targets_b_aug) + criterion(all_y[data.size()[0]:], train_label)

                proxy = model.classifier
                # PCLoss only on the features of the clean half.
                features = model.fea_rep[data.size()[0]:]
                loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)

            # NOTE(review): aug_p == 0.33 exactly fails both tests and falls to the else branch.
            elif 0.33 < aug_p < 0.66:
                # CutMix branch, same layout as the mixup branch.
                # NOTE(review): if cutmix_data writes into `data` in place, the
                # second half of all_x is NOT clean -- confirm cutmix_data returns a copy.
                inputs_aug, targets_a_aug, targets_b_aug, lam = cutmix_data(data, train_label, device_idx=args.gpu)
                all_x = torch.cat([inputs_aug, data])
                all_y = model(all_x)
                loss_cls = lam * criterion(all_y[:data.size()[0]], targets_a_aug) + (1 - lam) * criterion(
                    all_y[:data.size()[0]], targets_b_aug) + criterion(all_y[data.size()[0]:], train_label)
                proxy = model.classifier
                features = model.fea_rep[data.size()[0]:]
                loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)

            else:
                # Consistency branch: clean + two augmented views in one forward pass.
                all_x = torch.cat([data, data_aug1, data_aug2])
                all_y = model(all_x)

                targets = train_label
                logits_clean, logits_aug1, logits_aug2 = torch.split(
                    all_y, data.size(0))

                # Cross-entropy is only computed on clean images
                loss_cls = criterion(logits_clean, targets)

                p_clean, p_aug1, p_aug2 = F.softmax(logits_clean, dim=1), F.softmax(logits_aug1, dim=1), F.softmax(
                    logits_aug2, dim=1)

                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
                # Mean of KL(p_i || mixture) over the three views, weighted by 12.
                loss_kl = 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

                loss_cls += loss_kl
                proxy = model.classifier
                # PCLoss only on the clean third of the features.
                features = model.fea_rep[:data.size()[0]]
                loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)



            loss = loss_cls + args.pcl_weight *  loss_pcl

            # Backpropagation
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), data.size(0))


        scheduler.step()

        # write to tensorboard
        writer.add_scalar('training_loss/pretrain_CEL', losses.avg, epoch)
        writer.add_scalar('accuracy/pretrain_train', acc.avg, epoch)

        val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
        writer.add_scalar('validation_loss/pretrain_CEL', val_loss, epoch)
        writer.add_scalar('accuracy/pretrain_val', val_acc_mean, epoch)

        is_best = val_acc_mean > best_acc1
        best_acc1 = max(val_acc_mean, best_acc1)

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'epoch: {epoch} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f}\n')

        print('epoch:', epoch, 'current mean acc', val_acc_mean, 'best acc:', best_acc1)
        # A checkpoint is saved every epoch; is_best is passed along to save_checkpoint.
        save_checkpoint({
            'train_iter': epoch + 1,
            'arch': args.block_architecture,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, savedir=args.log_dir)

    writer.close()




def metatrain_baseFSCIL(verbose, **parameters):
    '''
    Meta-training on base session.

    Builds loaders with get_base_dataloader2 (train/val) and get_clean_dataloader,
    creates a KeyValueNetwork, restores model/optimizer/scheduler through
    load_checkpoint, then trains all model parameters with SGD + StepLR on
    CE + args.pcl_weight * PCLoss. Batches with aug_p < 0.1 (about 10%) use mixup.
    After each epoch the model is validated, metrics go to TensorBoard and
    ``<log_dir>/metatrain_log.txt``, and a checkpoint is saved. When validation
    accuracy improves, the (prototype, cov, classlabel) returned by
    ``model.protoSave(train_loader_clean)`` are stored in that checkpoint;
    otherwise those entries are None.

    Args:
        verbose: if true, print every parameter and append it to metatrain_log.txt.
        **parameters: run configuration, wrapped in a DotMap as ``args``.
    '''

    # Argument Preparation
    args = DotMap(parameters)
    # args.gpu = 5
    writer = SummaryWriter(args.log_dir)

    # Initialize the dataset generator and the model
    args = set_up_datasets(args)
    trainset, train_loader, val_loader = get_base_dataloader2(args)

    # Only train_loader_clean is used below (for model.protoSave).
    trainset_clean, train_loader_clean, val_loader_clean = get_clean_dataloader(args)

    model = KeyValueNetwork(args)

    model.mode = 'pretrain'
    # Store all parameters in a variable
    parameters_list, parameters_table = process_dictionary(parameters)

    # Despite the name, this is the path of the text log file, not a directory.
    logs_dir = os.path.join(args.log_dir + '/' + 'metatrain_log.txt')
    # Print all parameters
    if verbose:
        print("Parameters:")
        for key, value in parameters_list:
            print("\t{}".format(key).ljust(40) + "{}".format(value))
            with open(logs_dir, 'a', encoding='utf-8') as f1:
                f1.write("\t{}".format(key).ljust(40) + "{}".format(value) + '\n')



    # Take start time
    # NOTE(review): start_time is not read anywhere in this function.
    start_time = time.time()

    criterion = nn.CrossEntropyLoss()

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)


    # for param in model.embedding.parameters():
    #     param.requires_grad = True
    #
    # model.classifier.requires_grad=False
    # optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
    #                         lr=args.learning_rate,nesterov=args.SGDnesterov,
    #                         weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)

    # Unlike pretrain_baseFSCIL, every model parameter is handed to the optimizer.
    optimizer = t.optim.SGD(model.parameters(),
                            lr=args.learning_rate,nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)



    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)


    # NOTE(review): the restored best_acc1 is discarded (reset to 0) and
    # start_train_iter is unused, so training always restarts at epoch 1 and the
    # first epoch counts as best unless its validation accuracy is 0.
    model, optimizer, scheduler, start_train_iter, best_acc1= load_checkpoint(model, optimizer, scheduler, args)
    best_acc1 = 0

    # plot_confusion_support(model.classifier.data.cpu(),
    #                        savepath="{:}/session{:}".format(args.log_dir, 'cos_i_p'))
    # train_iterator = iter(train_loader)


    # Runs args.max_train_iter - 1 epochs (numbering starts at 1).
    for i in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
        global_count = 0  # counted per batch but not read elsewhere in this function

        losses = AverageMeter('Loss')
        acc = AverageMeter('Acc@1')  # NOTE(review): never updated or read in this function

        # Also undoes the model.eval() issued during the previous epoch's protoSave.
        model.train(True)


        for j, batch in enumerate(train_loader):
            global_count = global_count + 1
            # data, data_aug, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            # data, data_aug1,data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            # Each batch unpacks into images and labels.
            data, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            optimizer.zero_grad()
            # One uniform draw per batch decides whether mixup is used.
            aug_p = torch.rand(1)



            # output = model(data)
            # loss_cls = criterion(output, train_label)
            # proxy = model.classifier
            # features = model.fea_rep
            # loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)

            if aug_p<0.1:
                # data = normalize(data)
                # Mixup branch: forward mixed + clean images in one pass.
                # First half -> mixed CE, second half -> clean CE.
                inputs_aug, targets_a_aug, targets_b_aug, lam = mixup_data(data, train_label, device_idx=args.gpu)

                all_x = torch.cat([inputs_aug, data])
                all_y = model(all_x)
                loss_cls = lam * criterion(all_y[:data.size()[0]], targets_a_aug) + (1 - lam) * criterion(
                    all_y[:data.size()[0]], targets_b_aug) + criterion(all_y[data.size()[0]:], train_label)

                proxy = model.classifier
                # PCLoss only on the features of the clean half.
                features = model.fea_rep[data.size()[0]:]
                loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)

            else:
                # Plain branch: CE + PCLoss on the unmodified batch.
                output = model(data)
                loss_cls = criterion(output, train_label)
                proxy = model.classifier
                features = model.fea_rep
                loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)


            loss = loss_cls + args.pcl_weight*loss_pcl

            # Backpropagation
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), data.size(0))
        scheduler.step()

        val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
        writer.add_scalar('validation_loss/log_loss', val_loss, i)
        writer.add_scalar('accuracy/validation', val_acc_mean, i)

        is_best = val_acc_mean > best_acc1
        best_acc1 = max(val_acc_mean, best_acc1)

        # Prototype statistics are recomputed only when validation accuracy improves.
        if is_best:
            model.eval()

            prototype, cov, classlabel = model.protoSave(train_loader_clean)
        else:
            prototype, cov, classlabel = None,None,None


        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'epoch: {i} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f} loss: {losses.avg}\n')
        print('epoch:', i, 'current mean acc', val_acc_mean, 'best acc:', best_acc1,'loss', losses.avg)

        # A checkpoint is saved every epoch; prototype/cov/classlabel are None
        # unless this epoch was the best so far.
        save_checkpoint({
            'train_iter': i + 1,
            'arch': args.block_architecture,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'prototype' : prototype,
            'cov': cov,
            'classlabel': classlabel,


        }, is_best, savedir=args.log_dir)

    writer.close()




def _plot_tsne(model, args, file_name='tsne_plot_cub200.pdf'):
    '''
    Project embedding features of the clean loader to 2-D with t-SNE and save a scatter plot.

    Features come from ``model.embedding`` applied to every batch of
    ``get_clean_dataloader(args)``; points are coloured by label. The figure is
    written as a PDF into ``args.log_dir`` and closed afterwards so repeated
    calls do not accumulate open figures.
    '''
    _, clean_loader, _ = get_clean_dataloader(args)

    feature_chunks = []  # per-batch feature matrices
    label_chunks = []  # per-batch labels

    model.eval()
    with t.no_grad():
        for x, target in clean_loader:
            x = x.cuda(args.gpu, non_blocking=True)
            feature_chunks.append(model.embedding(x).cpu().numpy())
            label_chunks.append(target.cpu().numpy())

    data_array = np.concatenate(feature_chunks, axis=0)
    labels_array = np.concatenate(label_chunks, axis=0)

    # Reduce to 2-D with t-SNE (fixed seed for reproducible layouts).
    tsne_result = TSNE(n_components=2, random_state=42).fit_transform(data_array)
    print(tsne_result.shape)

    fig = plt.figure(figsize=(8, 6), dpi=500)
    # pyplot.get_cmap works across matplotlib 3.x; plt.cm.get_cmap was removed in 3.9.
    plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=labels_array,
                cmap=plt.get_cmap("jet", 10), marker='o')
    plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    # os.path.join supplies the separator that plain concatenation with log_dir omitted.
    plt.savefig(os.path.join(args.log_dir, file_name), format='pdf')
    plt.close(fig)


def train_FSCIL(verbose=False, **parameters):
    '''
    Main FSCIL evaluation on all sessions.

    Builds a KeyValueNetwork, freezes every parameter except
    ``model.embedding.fc``, restores weights via ``load_checkpoint`` and, per
    session, saves a t-SNE plot of the clean-loader embeddings and then either
    calls ``original_align`` (``args.retrain_iter == 0``) or ``feat_replay``.
    The accuracies returned by ``feat_replay`` are collected; at the last
    session their mean is appended to ``<log_dir>/test_log.txt`` and printed.

    NOTE(review): the session loop currently runs only session 0
    (``range(1)``), so the final summary is written only when
    ``args.sessions == 1``.

    Args:
        verbose: if true, print every parameter.
        **parameters: run configuration, wrapped in a DotMap as ``args``.
    '''
    args = DotMap(parameters)
    args = set_up_datasets(args)

    model = KeyValueNetwork(args, mode="pretrain")

    # Store all parameters in a variable
    parameters_list, parameters_table = process_dictionary(parameters)

    # Print all parameters
    if verbose:
        print("Parameters:")
        for key, value in parameters_list:
            print("\t{}".format(key).ljust(40) + "{}".format(value))

    # Write parameters to file
    if not args.inference_only:
        filename = args.log_dir + '/parameters.csv'
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # newline='' is what the csv module expects, avoiding blank rows on Windows.
        with open(filename, 'w', newline='') as csv_file:
            csv_writer = csv.writer(csv_file)
            keys, values = zip(*parameters_list)
            csv_writer.writerow(keys)
            csv_writer.writerow(values)

    writer = SummaryWriter(args.log_dir)

    criterion = nn.CrossEntropyLoss()

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # set all parameters except FC to trainable false
    for param in model.parameters():
        param.requires_grad = False
    for param in model.embedding.fc.parameters():
        param.requires_grad = True

    model.classifier.requires_grad = False

    optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.learning_rate, nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)

    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)

    model, optimizer, scheduler, start_train_iter, best_acc1 = load_checkpoint(model, optimizer, scheduler, args)

    logs_dir = os.path.join(args.log_dir + '/' + 'test_log.txt')
    all_acc = []
    prototype = model.classifier.data.cpu().numpy()
    # `cov` is reassigned inside the loop, which makes it a local name; without this
    # initialisation the first session raised UnboundLocalError.
    # NOTE(review): None is passed on the first session -- confirm the callees accept it.
    cov = None

    for session in range(1):
        nways_session = args.base_class + session * args.way

        if session > 0:
            model.mode = 'meta'

        train_set, train_loader, test_loader, test_loader_new = get_dataloader(args, session)

        # The alignment routines below consume a single batch of the session's training data.
        batch = next(iter(train_loader))

        _plot_tsne(model, args)

        if args.retrain_iter == 0:
            original_align(model, batch, optimizer, args, writer, session, nways_session, prototype, cov)
        else:
            base_proto, base_cov, acc_each_session = feat_replay(model, batch, optimizer, args, writer, session,
                                                                 nways_session, prototype, cov, test_loader,
                                                                 best_acc1)
            cov = base_cov
            prototype = base_proto
            all_acc.append(acc_each_session)

        if session == args.sessions - 1:
            # all_acc stays empty when retrain_iter == 0; avoid numpy's empty-mean warning.
            mean_acc = np.mean(all_acc) if all_acc else float('nan')
            with open(logs_dir, 'a', encoding='utf-8') as f1:
                f1.write(f'Mean Acc for this run is: {mean_acc}\t Each Session Acc is{all_acc} \n')
                print(f'Mean Acc for this run is: {mean_acc}\n Each Session Acc is{all_acc} \n')
    writer.close()



def original_align(model, data, optimizer, args, writer, session, nways_session, prototype, cov):
    '''
    Re-align the embedding FC layer to the memory prototypes with a cosine loss.

    Pipeline:
      1. push the features of ``data`` into the model's feature-replay buffer;
      2. rebuild the prototypes from the replayed features;
      3. nudge the prototypes and, if ``args.bipolarize_prototypes``, bipolarize them;
      4. for incremental sessions only (``session > 0``), run ``args.retrain_iter``
         optimizer steps on the cosine loss between the support features and
         ``model.key_mem``, logging each loss to ``writer``;
      5. rebuild the prototypes from the replayed features once more;
      6. optionally apply HRR superposition when ``args.em_compression == "hrr"``.

    Args:
        model: network exposing the feature-replay / prototype API used below.
        data: ``(inputs, labels)`` pair, e.g. one batch from the session loader.
        optimizer: optimizer stepped in stage 4.
        args: run configuration (retrain_act, batch_size_training, gpu, ...).
        writer: TensorBoard writer for the retraining loss.
        session: index of the current session.
        nways_session: number of classes seen up to this session.
        prototype, cov: accepted for signature compatibility; not used here.
    '''
    losses = AverageMeter('Loss')  # NOTE(review): never updated or read here
    align_loss = myCosineLoss(args.retrain_act)
    loader = DataLoader(dataset=myRetrainDataset(data[0], data[1]),
                        batch_size=args.batch_size_training)

    # Stage 1: extract features of the new data into the replay buffer.
    model.eval()
    with t.no_grad():
        for inputs, targets in loader:
            model.update_feat_replay(inputs.cuda(args.gpu, non_blocking=True),
                                     targets.cuda(args.gpu, non_blocking=True))

    # Stage 2: prototypes from the replayed features (GAAM).
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3: nudging, then optional bipolarization.
    model.nudge_prototypes(nways_session, writer, session, args.gpu)
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: retrain against the memory keys, only for incremental sessions.
    model.embedding.fc.train()
    if session > 0:
        for step in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = align_loss(support[:nways_session], model.key_mem.data[:nways_session])
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), step)

    # Stage 5: fill up the prototypes again.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: optional EM compression.
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)



def proto_align_v5(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov, test_loader, best_acc1, test_loader_new):
    '''
    Train a fresh linear classifier on features sampled from calibrated class Gaussians.

    Stage 1 feeds the current session's data (``data[0]`` / ``data[1]``) to the
    model's feature-replay buffer. Stage 2 rebuilds the prototypes from that
    buffer. In session 0 only the source-model accuracy ``best_acc1`` is logged.

    In later sessions:
      * A Sinkhorn transport plan ``Pi`` is computed between the saved base
        prototypes and the current novel prototypes in ``model.key_mem``.
      * 100 features per base class are sampled from
        ``N(base_prototype[c], base_cov[c])``.
      * 100 features per novel class are sampled from the Gaussian returned by
        ``distribution_calibration_dan``.
      * The samples, together with the real embeddings of ALL session batches,
        train a ``MultiClassLogisticRegression`` head for 10000 full-batch SGD
        steps. Its first ``args.base_class`` rows are then overwritten with
        ``model.classifier``.
      * The head is evaluated on ``test_loader``, on ``test_loader_new`` (if
        given), and on each session's classes separately. Results are printed
        and appended to ``<args.log_dir>/test_log.txt``.

    ``optimizer`` and ``writer`` are accepted for interface compatibility but
    are not used; a dedicated SGD optimizer is built for the new head.

    Returns:
        ``(all_proto, all_cov, acc_each_session)``.
        In session 0 this is ``base_prototype``, ``base_cov`` and ``best_acc1``.
        Otherwise the base statistics are extended with the calibrated novel
        means/covariances, and the accuracy is the overall top-1 on
        ``test_loader``.
    '''
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    classifier = MultiClassLogisticRegression(input_dim=args.dim_features, output_dim=nways_session).cuda(args.gpu)
    logs_dir = os.path.join(args.log_dir, 'test_log.txt')

    # Stage 1: Compute feature representation of new data.
    # Real embeddings are collected from EVERY batch. Keeping only the last
    # batch would drop samples whenever the session data spans several batches.
    # They are only consumed from session 1 on, so session 0 skips them.
    real_features, real_labels = [], []
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            if session > 0:
                real_features.append(model.embedding(x))
                real_labels.append(target)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    if session == 0:
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'\nSource Model:{args.resume}\n')
            f1.write(f'{best_acc1}\t{best_acc1}\n')

    if session > 0:
        nways_session = args.base_class + session * args.way
        n_base = args.base_class

        # Stage 3: transport plan between saved base prototypes and current novel prototypes
        sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
        c_proto = model.key_mem.data
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        _, Pi, _ = sinkhorn(base_torch[:n_base], c_proto[n_base:nways_session])
        c_proto = c_proto.cpu().numpy()

        # Stage 4a: sample base-class features from the saved Gaussians
        sample_num_old = 100
        feats_old = [np.random.multivariate_normal(mean=base_prototype[idx], cov=base_cov[idx], size=sample_num_old)
                     for idx in range(n_base)]
        labels_old = [np.full(sample_num_old, idx, dtype=np.int64) for idx in range(n_base)]

        # Stage 4b: calibrate and sample every novel class seen so far
        sample_num = 100
        feats_new, labels_new, novel_means, novel_covs = [], [], [], []
        for cls in range(n_base, nways_session):
            mean, cov = distribution_calibration_dan(c_proto[cls], Pi[:, cls - n_base], base_prototype[:n_base],
                                                     base_cov[:n_base], n_lsamples=args.way)
            feats_new.append(np.random.multivariate_normal(mean=mean, cov=cov, size=sample_num))
            labels_new.append(np.full(sample_num, cls, dtype=np.int64))
            novel_means.append(np.expand_dims(mean, 0))
            novel_covs.append(np.expand_dims(cov, 0))

        # Concatenate once, instead of re-copying the growing (C, D, D) stack per class
        prototype_saver = np.concatenate([base_prototype, *novel_means])
        cov_saver = np.concatenate([base_cov, *novel_covs])

        features_all = torch.from_numpy(np.concatenate(feats_old + feats_new, axis=0)).cuda(args.gpu).float()
        labels_all = torch.from_numpy(np.concatenate(labels_old + labels_new)).cuda(args.gpu)
        if real_features:
            features_all = torch.cat([features_all, *real_features], dim=0)
            labels_all = torch.cat([labels_all, *real_labels], dim=0)
        # CrossEntropyLoss requires int64 class indices
        labels_all = labels_all.long()

        # Stage 4c: full-batch training of the new linear head
        num_epochs = 10000
        clf_optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, nesterov=args.SGDnesterov,
                                        weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
        # Uniform class weights (base and novel alike). Kept as an explicit tensor
        # so base/novel re-weighting can be tuned; built once, not per step.
        class_weights = torch.ones(nways_session).cuda(args.gpu)
        loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        classifier.train()
        for _ in range(num_epochs):
            loss = loss_fn(classifier(features_all), labels_all)
            clf_optimizer.zero_grad()
            loss.backward()
            clf_optimizer.step()

        # Base-class rows of the new head are replaced by the model's own base classifier
        classifier.linear.data[:n_base] = model.classifier.data
        model.eval()
        classifier.eval()

        def _top1(loader):
            '''Return an AverageMeter holding the top-1 accuracy of ``classifier`` on ``loader``.'''
            meter = AverageMeter('Acc@1', ':6.2f')
            with t.no_grad():
                for batch in loader:
                    images, targets = [item.cuda(args.gpu, non_blocking=True) for item in batch]
                    predicts = classifier(model.embedding(images))
                    meter.update(top1accuracy(predicts.argmax(dim=1), targets).item(), images.size(0))
            return meter

        acc_each_session = _top1(test_loader).avg
        print("Session {:} Testing Acc: {:.2f}%".format(session, acc_each_session))

        all_acc_new = []
        if test_loader_new is not None:
            all_acc_new.append(_top1(test_loader_new).avg)
        print(all_acc_new)

        # Per-session accuracy: base classes, then each block of args.way novel classes
        acc_up2now = []
        for sess_idx in range(session + 1):
            if sess_idx == 0:
                classes = np.arange(args.num_classes)[:n_base]
            else:
                start = n_base + (sess_idx - 1) * args.way
                classes = np.arange(args.num_classes)[start:start + args.way]
            if args.dataset == 'cifar100':
                test_for_each = args.Dataset.CIFAR100(root=args.data_folder, train=False, index=classes,
                                                      base_sess=False)
            elif args.dataset == 'mini_imagenet':
                test_for_each = args.Dataset.MiniImageNet(root=args.data_folder, train=False,
                                                          index=classes)
            else:
                test_for_each = args.Dataset.CUB200(root=args.data_folder, train=False, index=classes)

            testloader2 = torch.utils.data.DataLoader(dataset=test_for_each, batch_size=args.batch_size_inference,
                                                      shuffle=False, pin_memory=True)
            acc_up2now.append(_top1(testloader2).avg)
        print(acc_up2now)

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'{acc_up2now} Acc each session:\t{acc_each_session}\t Novel classes Acc:{all_acc_new}\n')

    # Stage 5: leave the model in eval mode
    model.eval()

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    if session == 0:
        return base_prototype, base_cov, best_acc1
    return prototype_saver, cov_saver, acc_each_session







def feat_replay(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov, test_loader, best_acc1):
    '''
    Train a fresh linear classifier on replayed (sampled) features plus real session features.

    Stage 1 feeds the current session's data (``data[0]`` / ``data[1]``) to the
    model's feature-replay buffer. Stage 2 rebuilds the prototypes from that
    buffer. In session 0 only the source-model accuracy ``best_acc1`` is logged.

    In later sessions:
      * A Sinkhorn transport plan ``Pi`` is computed between the saved base
        prototypes and the current novel prototypes in ``model.key_mem``.
      * 400 features per base class are sampled from
        ``N(base_prototype[c], base_cov[c])``.
      * 400 features per novel class are sampled from the Gaussian returned by
        ``distribution_calibration_dan``.
      * The samples, together with the real embeddings of ALL session batches,
        train a ``MultiClassLogisticRegression`` head with unweighted
        cross-entropy for 10000 full-batch SGD steps.
      * The head is evaluated on ``test_loader`` and on each session's classes
        separately. Results are printed and appended to
        ``<args.log_dir>/test_log.txt``.

    ``optimizer`` and ``writer`` are accepted for interface compatibility but
    are not used; a dedicated SGD optimizer is built for the new head.

    Returns:
        ``(all_proto, all_cov, acc_each_session)``.
        In session 0 this is ``base_prototype``, ``base_cov`` and ``best_acc1``.
        Otherwise the base statistics are extended with the calibrated novel
        means/covariances, and the accuracy is the overall top-1 on
        ``test_loader``.
    '''
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    classifier = MultiClassLogisticRegression(input_dim=args.dim_features, output_dim=nways_session).cuda(args.gpu)
    logs_dir = os.path.join(args.log_dir, 'test_log.txt')

    # Stage 1: Compute feature representation of new data.
    # Real embeddings are collected from EVERY batch. Keeping only the last
    # batch would drop samples whenever the session data spans several batches.
    # They are only consumed from session 1 on, so session 0 skips them.
    real_features, real_labels = [], []
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            if session > 0:
                real_features.append(model.embedding(x))
                real_labels.append(target)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    if session == 0:
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'\nSource Model:{args.resume}\n')
            f1.write(f'{best_acc1}\t{best_acc1}\n')

    if session > 0:
        nways_session = args.base_class + session * args.way
        n_base = args.base_class

        # Stage 3: transport plan between saved base prototypes and current novel prototypes
        sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
        c_proto = model.key_mem.data
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        _, Pi, _ = sinkhorn(base_torch[:n_base], c_proto[n_base:nways_session])
        c_proto = c_proto.cpu().numpy()

        # Stage 4a: sample base-class features from the saved Gaussians
        sample_num_old = 400
        feats_old = [np.random.multivariate_normal(mean=base_prototype[idx], cov=base_cov[idx], size=sample_num_old)
                     for idx in range(n_base)]
        labels_old = [np.full(sample_num_old, idx, dtype=np.int64) for idx in range(n_base)]

        # Stage 4b: calibrate and sample every novel class seen so far
        sample_num = 400
        feats_new, labels_new, novel_means, novel_covs = [], [], [], []
        for cls in range(n_base, nways_session):
            mean, cov = distribution_calibration_dan(c_proto[cls], Pi[:, cls - n_base], base_prototype[:n_base],
                                                     base_cov[:n_base], n_lsamples=args.way)
            feats_new.append(np.random.multivariate_normal(mean=mean, cov=cov, size=sample_num))
            labels_new.append(np.full(sample_num, cls, dtype=np.int64))
            novel_means.append(np.expand_dims(mean, 0))
            novel_covs.append(np.expand_dims(cov, 0))

        # Concatenate once, instead of re-copying the growing (C, D, D) stack per class
        prototype_saver = np.concatenate([base_prototype, *novel_means])
        cov_saver = np.concatenate([base_cov, *novel_covs])

        features_all = torch.from_numpy(np.concatenate(feats_old + feats_new, axis=0)).cuda(args.gpu).float()
        labels_all = torch.from_numpy(np.concatenate(labels_old + labels_new)).cuda(args.gpu)
        if real_features:
            features_all = torch.cat([features_all, *real_features], dim=0)
            labels_all = torch.cat([labels_all, *real_labels], dim=0)
        # CrossEntropyLoss requires int64 class indices
        labels_all = labels_all.long()

        # Stage 4c: full-batch training of the new linear head (unweighted CE, built once)
        num_epochs = 10000
        clf_optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, nesterov=args.SGDnesterov,
                                        weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
        loss_fn = nn.CrossEntropyLoss()
        classifier.train()
        for _ in range(num_epochs):
            loss = loss_fn(classifier(features_all), labels_all)
            clf_optimizer.zero_grad()
            loss.backward()
            clf_optimizer.step()

        model.eval()
        classifier.eval()

        def _top1(loader):
            '''Return an AverageMeter holding the top-1 accuracy of ``classifier`` on ``loader``.'''
            meter = AverageMeter('Acc@1', ':6.2f')
            with t.no_grad():
                for batch in loader:
                    images, targets = [item.cuda(args.gpu, non_blocking=True) for item in batch]
                    predicts = classifier(model.embedding(images))
                    meter.update(top1accuracy(predicts.argmax(dim=1), targets).item(), images.size(0))
            return meter

        acc_each_session = _top1(test_loader).avg
        print("Session {:} Testing Acc: {:.2f}%".format(session, acc_each_session))

        # Per-session accuracy: base classes, then each block of args.way novel classes
        acc_up2now = []
        for sess_idx in range(session + 1):
            if sess_idx == 0:
                classes = np.arange(args.num_classes)[:n_base]
            else:
                start = n_base + (sess_idx - 1) * args.way
                classes = np.arange(args.num_classes)[start:start + args.way]
            if args.dataset == 'cifar100':
                test_for_each = args.Dataset.CIFAR100(root=args.data_folder, train=False, index=classes,
                                                      base_sess=False)
            elif args.dataset == 'mini_imagenet':
                test_for_each = args.Dataset.MiniImageNet(root=args.data_folder, train=False,
                                                          index=classes)
            else:
                test_for_each = args.Dataset.CUB200(root=args.data_folder, train=False, index=classes)

            testloader2 = torch.utils.data.DataLoader(dataset=test_for_each, batch_size=args.batch_size_inference,
                                                      shuffle=False, pin_memory=True)
            acc_up2now.append(_top1(testloader2).avg)
        print(acc_up2now)

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'{acc_up2now}\t{acc_each_session}\n')

    # Stage 5: leave the model in eval mode
    model.eval()

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    if session == 0:
        return base_prototype, base_cov, best_acc1
    return prototype_saver, cov_saver, acc_each_session




def proto_align(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    '''
    Align the FC head to distribution-calibrated prototypes using ``myCosineLoss``.

    Stages:
      1. Push the session data (``data[0]`` / ``data[1]``) through
         ``model.update_feat_replay``.
      2. Rebuild the prototypes from the replayed features.
      3. Nudge the prototypes and, if ``args.bipolarize_prototypes`` is set,
         bipolarize them.
      4. For ``session > 0``:
         a. Replace every novel prototype (rows ``args.base_class`` to
            ``nways_session``) with one draw from the Gaussian returned by
            ``distribution_calibration(..., base_prototype, base_cov, k=2)``.
         b. Run ``args.retrain_iter`` steps of ``optimizer`` on
            ``criterion(model.get_support_feat(feat), model.key_mem.data)``,
            logging the loss to ``writer``.
      5. Rebuild the prototypes with the retrained head.
      6. Optionally compress the memory with HRR superposition.
    '''
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3: Nudging
    model.nudge_prototypes(nways_session, writer, session, args.gpu)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Update / retrain the FC
    model.embedding.fc.train()

    if session > 0:
        # 4a: resample each novel prototype from its calibrated Gaussian
        c_proto = model.key_mem.data.cpu().numpy()
        for i in range(args.base_class, nways_session):
            mean, cov = distribution_calibration(c_proto[i], base_prototype, base_cov, k=2)
            # multivariate_normal(size=1) returns shape (1, D); it broadcasts into row i
            c_proto[i] = np.random.multivariate_normal(mean=mean, cov=cov, size=1)
        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)

        # 4b: pull the support features towards the (fixed) calibrated prototypes
        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])

            # Backpropagation
            loss.backward()
            optimizer.step()

            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)




# def proto_align_v2(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
#     '''
#     Alignment of FC using MSE Loss and feature replay
#     '''
#
#     losses = AverageMeter('Loss')
#
#
#     criterion = myCosineLoss(args.retrain_act)
#
#
#     dataset = myRetrainDataset(data[0], data[1])
#     dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
#     sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
#     sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
#     # Stage 1: Compute feature representation of new data
#     model.eval()
#     with t.no_grad():
#         for x, target in dataloader:
#             x = x.cuda(args.gpu, non_blocking=True)
#             target = target.cuda(args.gpu, non_blocking=True)
#             model.update_feat_replay(x, target)
#
#
#     # Stage 2: Compute prototype based on GAAM
#     feat, label = model.get_feat_replay()
#
#     model.reset_prototypes(args)
#     model.update_prototypes_feat(feat, label, nways_session)
#
#     # old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
#
#     before_caliber =  model.key_mem.data[:nways_session]
#     # Stage 3: Nuddging
#     # model.nudge_prototypes(nways_session, writer, session, args.gpu)
#
#     # Bipolarize prototypes in Mode 2
#     if args.bipolarize_prototypes:
#         model.bipolarize_prototypes()
#
#     # Stage 4: Update Retraining the FC
#
#     model.embedding.fc.train()
#
#
#     if session > 0:
#         nways_session = args.base_class + session * args.way
#         oways_session = args.base_class + (session - 1) * args.way
#
#         old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
#
#         base_proxy = model.key_mem.data[:model.args.base_class]
#
#         c_proto = torch.sign(model.key_mem.data)
#         # c_proto = model.key_mem.data
#         # c_proto = Tanh10x()(model.key_mem.data)
#
#
#
#         base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
#
#         # c_proto = Tanh10x()(model.key_mem.data)
#         # c_proto[oways_session:nways_session] = Tanh10x()(c_proto[oways_session:nways_session])
#         cost, Pi, C = sinkhorn(base_torch, c_proto[oways_session:nways_session])
#
#
#         c_proto = c_proto.cpu().numpy()
#
#
#         for i in range(oways_session,nways_session):
#
#             mean, cov = distribution_calibration_dan(c_proto[i], Pi[:, i-oways_session], base_prototype, base_cov,
#                                                      n_lsamples=args.way )
#
#
#             # proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=1)
#             #
#             # c_proto[i] = proto_temp
#
#             proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
#
#
#
#             # probabi = Pi[:, i-oways_session].cpu()
#             # proab_reshape = np.repeat(args.way* probabi.numpy(), base_prototype[0].shape[0], axis=0).reshape(len(base_prototype), base_prototype[0].shape[0])
#             # preivous_mean = np.sum(proab_reshape * np.concatenate([base_prototype[:]]),axis=0)
#             #
#             # preivous_mean = torch.from_numpy(preivous_mean).float().cuda(args.gpu)
#             # proto_temp2 = torch.from_numpy(proto_temp).float().cuda(args.gpu)
#             # cost2, Pi2, C2 = sinkhorn_multi(preivous_mean, proto_temp2)
#             # print(Pi2.shape)
#
#             proto_temp2 = torch.from_numpy(proto_temp).float().cuda(args.gpu)
#
#             # reference = torch.from_numpy(mean).float().cuda(args.gpu)
#             similar = cosine_similarity_multi(proto_temp2, base_proxy, rep=args.representation)
#
#             similar_most = torch.argmin(-similar.sum(dim=1))
#
#             c_proto[i] = proto_temp[similar_most.cpu().numpy()]
#             # c_proto[i] = mean
#         c_proto = torch.from_numpy(c_proto).float().cuda(args.gpu)
#
#         model.key_mem.data = c_proto
#
#         model.nudge_prototypes(nways_session, writer, session, args.gpu)
#
#
#         for epoch in range(args.retrain_iter):
#
#             optimizer.zero_grad()
#             support = model.get_support_feat(feat)
#
#
#             loss_cls = criterion(support[:nways_session], model.key_mem.data[:nways_session])
#
#
#
#             label2 = torch.arange(oways_session, nways_session).cuda(args.gpu, non_blocking=True)
#
#             # cost2, Pi2, C2 = sinkhorn(model.key_mem.data[:oways_session], model.key_mem.data[oways_session:nways_session])
#
#             #
#
#             Pi2 = model.pseduo_feature_inference(support[oways_session:nways_session])
#
#
#             loss_pcl = PCALoss(num_classes=nways_session, scale=12) \
#                 (support[:nways_session], support[oways_session:nways_session],label2, model.key_mem.data[:nways_session], old_proxy, base_proxy,Pi2.transpose(1,0), mweight=2)
#
#
#
#             # loss = 1.0 * loss_cls +0.00*loss_pcl
#             loss = 1.0 * loss_cls
#             # loss =  loss_n_cls
#
#
#
#             # Backpropagation
#             loss.backward()
#             optimizer.step()
#
#             writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)
#
#     # plot_confusion_support(model.key_mem.data[:nways_session].cpu(),
#     #                        savepath="{:}/relu2{:}".format(args.log_dir, str(session)))
#
#
#
#     # if session == 1:
#     #     np.random.seed(42)
#     #
#     #
#     #     X = model.key_mem.data[:nways_session].cpu().numpy()
#     #     X2 =  before_caliber.cpu().numpy()[100:110]  # 示例数据，50 个样本，每个样本有 10 个特征
#     #
#     #     X = np.concatenate((X,X2), axis=0)
#     #
#     #
#     #     # 使用 t-SNE 进行降维
#     #     tsne = TSNE(n_components=2, random_state=42)
#     #     X_tsne = tsne.fit_transform(X)
#     #
#     #     # 绘制 t-SNE 图
#     #     num_samples = X_tsne.shape[0]
#     #     colors = plt.cm.get_cmap('Pastel1', num_samples)
#     #     color_indices = np.arange(100)  # 前 60 个样本的索引
#     #
#     #     # 绘制 t-SNE 图
#     #     plt.figure(figsize=(8, 6))
#     #
#     #     # 根据颜色映射绘制前 60 个样本的散点图
#     #     plt.scatter(X_tsne[color_indices, 0], X_tsne[color_indices, 1], c=color_indices, cmap=colors, marker='o', s=50, label='Base Prototypes')
#     #
#     #     plt.scatter(X_tsne[100:110, 0], X_tsne[100:110, 1], c='b', marker='o', s=50, label='Calibrated new Prototypes')
#     #
#     #
#     #     # 绘制剩下的样本为蓝色
#     #     plt.scatter(X_tsne[110:, 0], X_tsne[110:, 1], c='r', marker='o', s=50, label='Original new Prototypes')
#     #
#     #     plt.title('t-SNE Visualization')
#     #     plt.xlabel('t-SNE Dimension 1')
#     #     plt.ylabel('t-SNE Dimension 2')
#     #     plt.legend()
#     #
#     #     save_path = args.log_dir + 'tsne_plot_cub_all.pdf'
#     #     # 保存图为 PDF 文件
#     #     plt.savefig(save_path, format='pdf')
#     # Stage 5: Fill up prototypes again
#     model.eval()
#     model.reset_prototypes(args)
#     model.update_prototypes_feat(feat, label, nways_session)
#
#
#     # Stage 6: Optional EM compression
#     if args.em_compression == "hrr":
#         model.hrr_superposition(nways_session, args.em_compression_nsup)
def proto_align_v2(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Refresh prototypes from new data; for incremental sessions also calibrate and retrain.

    Stages as implemented below:
      1. Run ``data`` through the model (eval mode, no grad) and hand every batch
         to ``model.update_feat_replay``.
      2. Fetch ``(feat, label)`` from ``model.get_feat_replay`` and call
         ``model.reset_prototypes`` / ``model.update_prototypes_feat``.
      3. (disabled) nudging; optional ``bipolarize_prototypes``.
      4. Only if ``session > 0``:
         - solve Sinkhorn between ``base_prototype`` and the current session's
           rows of ``model.key_mem``;
         - for each new class, build a Gaussian with ``distribution_calibration_dan``
           and draw ``args.sample_num`` samples;
         - re-aggregate the samples through a second Sinkhorn plan and the first
           plan's column, and write the result into ``model.key_mem``;
         - call ``model.nudge_prototypes``;
         - train for ``args.retrain_iter`` optimizer steps with ``myCosineLoss``.
      5. Call ``reset_prototypes`` / ``update_prototypes_feat`` again.
      6. Optional HRR superposition when ``args.em_compression == "hrr"``.

    Args:
        model: model exposing ``embedding`` (with an ``fc`` submodule),
            ``key_mem`` and the prototype / feature-replay methods called below.
        data: pair ``(inputs, labels)``, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during the retraining loop.
        args: config namespace (``gpu``, ``batch_size_training``, ``retrain_act``,
            ``base_class``, ``way``, ``sample_num``, ``retrain_iter``,
            ``bipolarize_prototypes``, ``em_compression``, ...).
        writer: object with ``add_scalar`` (e.g. a TensorBoard ``SummaryWriter``).
        session: session index; ``0`` skips calibration and retraining.
        nways_session: number of classes used in stage 2; overwritten with
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: NumPy array of base-class prototypes (converted with
            ``torch.from_numpy``); also passed as ``base_means`` to the calibration.
        base_cov: base-class covariances passed to ``distribution_calibration_dan``.

    Note:
        Sampling uses NumPy's global RNG (``np.random.multivariate_normal``),
        so results are non-deterministic unless it is seeded.
    """

    # NOTE(review): `losses` is never used in this function.
    losses = AverageMeter('Loss')
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    # Two solvers with identical settings: one for the session-level plan,
    # one for the per-class plan over the drawn samples.
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            # NOTE(review): `all_feature` is unused (update_feat_replay receives the
            # raw `x`); this forward pass looks redundant.
            all_feature = model.embedding(x)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)


    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()

    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
    # Stage 3: Nudging (disabled here; nudging happens after calibration in stage 4)
    # model.nudge_prototypes(nways_session, writer, session, args.gpu)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Update Retraining the FC

    # Only the fc submodule is put into train mode; the rest of the model stays in eval mode.
    model.embedding.fc.train()


    if session > 0:
        # Rows [oways_session, nways_session) of key_mem belong to this session's new classes.
        nways_session = args.base_class + session *args.way
        oways_session = args.base_class + (session - 1) * args.way


        # c_proto = model.key_mem.data

        # NOTE(review): `old_proxy` and `base_proxy` are never used below.
        old_proxy = model.key_mem.data[model.args.base_class:model.args.base_class + (session) * args.way]

        base_proxy = model.key_mem.data[:model.args.base_class]


        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        # base_torch = model.key_mem.data[:args.base_class]

        # c_proto = Tanh10x()(model.key_mem.data)
        # c_proto = torch.sign(model.key_mem.data)
        c_proto = model.key_mem.data

        # Pi[:, j] is used below as the weights of new class (oways_session + j)
        # over the base prototypes (presumably the transport plan — TODO confirm).
        cost, Pi, C = sinkhorn(base_torch, c_proto[oways_session:nways_session])


        # Work on a host copy; rows of the new classes are overwritten in the loop.
        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session,nways_session):

            mean, cov = distribution_calibration_dan(c_proto[i], Pi[:, i-oways_session], base_prototype, base_cov,
                                                     n_lsamples=args.way)

            # Draw args.sample_num samples from the calibrated Gaussian.
            proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=args.sample_num)

            proto_temp3 = torch.from_numpy(proto_temp).float().cuda(args.gpu)

            cost2, Pi2, C2 = sinkhorn_multi(base_torch,proto_temp3)

            # Pi2 @ samples gives one aggregated vector per base prototype; weighting
            # those with Pi's column for class i collapses them to a single vector
            # (assumes Pi2 is (n_base, sample_num) — TODO confirm).
            new_temp = torch.matmul(Pi[:, i-oways_session], torch.matmul(Pi2, proto_temp3))

            c_proto[i] = new_temp.cpu().numpy()
            # c_proto[i] = new_temp.cpu().numpy()

        c_proto = torch.from_numpy(c_proto).float().cuda(args.gpu)

        # model.key_mem.data = Tanh10x()(c_proto)
        # Replace the whole key memory with the calibrated copy.
        model.key_mem.data = c_proto

        model.nudge_prototypes(nways_session, writer, session, args.gpu)



        # Retrain: align support features of all seen classes with key_mem via the cosine loss.
        for epoch in range(args.retrain_iter):

            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss_cls = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            loss = 1.0 * loss_cls
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)


    # Stage 5: Fill up prototypes again
    # NOTE(review): resetting and recomputing from `feat` appears to overwrite the
    # calibrated key_mem written above; their effect would then survive only through
    # the retrained fc. Confirm this is intended.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)


    # plot_confusion_support(model.key_mem.data.cpu(),
    #                     savepath="{:}/relu2{:}".format(args.log_dir, str(session)))

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)





def proto_align_final(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Variant of ``proto_align_v2`` that calibrates against the model's own base prototypes.

    Unlike ``proto_align_v2``, the base prototypes used for Sinkhorn and for
    ``distribution_calibration_dan`` are ``model.key_mem.data[:args.base_class]``.
    The ``base_prototype`` argument is overwritten inside the ``session > 0``
    branch and is therefore ignored (kept for interface compatibility).

    Stages as implemented below:
      1. Run ``data`` through the model (eval mode, no grad) and hand every batch
         to ``model.update_feat_replay``.
      2. Fetch ``(feat, label)`` from ``model.get_feat_replay`` and call
         ``model.reset_prototypes`` / ``model.update_prototypes_feat``.
      3. (disabled) nudging; optional ``bipolarize_prototypes``.
      4. Only if ``session > 0``:
         - solve Sinkhorn between the base rows of ``key_mem`` and the current
           session's rows;
         - for each new class, build a Gaussian with ``distribution_calibration_dan``
           and draw ``args.sample_num`` samples;
         - re-aggregate the samples through two plans and write the result into
           ``key_mem``;
         - call ``nudge_prototypes``;
         - train for ``args.retrain_iter`` steps with ``myCosineLoss``.
      5. Call ``reset_prototypes`` / ``update_prototypes_feat`` again.
      6. Optional HRR superposition when ``args.em_compression == "hrr"``.

    Args:
        model: model exposing ``embedding`` (with ``fc``), ``key_mem`` and the
            prototype / feature-replay methods called below.
        data: pair ``(inputs, labels)``, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during the retraining loop.
        args: config namespace (``gpu``, ``batch_size_training``, ``retrain_act``,
            ``base_class``, ``way``, ``sample_num``, ``retrain_iter``,
            ``bipolarize_prototypes``, ``em_compression``, ...).
        writer: object with ``add_scalar`` (e.g. a TensorBoard ``SummaryWriter``).
        session: session index; ``0`` skips calibration and retraining.
        nways_session: number of classes used in stage 2; overwritten with
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: ignored (see above).
        base_cov: base-class covariances passed to ``distribution_calibration_dan``.

    Note:
        Sampling uses NumPy's global RNG (``np.random.multivariate_normal``).
    """

    # NOTE(review): `losses` is never used in this function.
    losses = AverageMeter('Loss')
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    # Two solvers with identical settings: session-level plan and per-class sample plan.
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            # NOTE(review): `all_feature` is unused; this forward pass looks redundant.
            all_feature = model.embedding(x)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)


    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()

    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
    # Stage 3: Nudging (disabled here; nudging happens after calibration in stage 4)
    # model.nudge_prototypes(nways_session, writer, session, args.gpu)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Update Retraining the FC

    # Only the fc submodule is put into train mode.
    model.embedding.fc.train()


    if session > 0:
        # Rows [oways_session, nways_session) of key_mem belong to this session's new classes.
        nways_session = args.base_class + session *args.way
        oways_session = args.base_class + (session - 1) * args.way


        # c_proto = model.key_mem.data

        # old_proxy = model.key_mem.data[model.args.base_class:model.args.base_class + (session) * args.way]
        #
        # base_proxy = model.key_mem.data[:model.args.base_class]


        # base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        # Base prototypes come from the model itself; the argument is shadowed here.
        base_torch = model.key_mem.data[:args.base_class]

        base_prototype =base_torch.cpu().numpy()

        # c_proto = Tanh10x()(model.key_mem.data)
        # c_proto = torch.sign(model.key_mem.data)
        c_proto = model.key_mem.data

        # Pi[:, j] is used below as the weights of new class (oways_session + j)
        # over the base prototypes (presumably the transport plan — TODO confirm).
        cost, Pi, C = sinkhorn(base_torch, c_proto[oways_session:nways_session])


        # Host copy; rows of the new classes are overwritten in the loop.
        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session,nways_session):

            mean, cov = distribution_calibration_dan(c_proto[i], Pi[:, i-oways_session], base_prototype, base_cov,
                                                     n_lsamples=args.way)

            # Draw args.sample_num samples from the calibrated Gaussian.
            proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=args.sample_num)

            proto_temp3 = torch.from_numpy(proto_temp).float().cuda(args.gpu)

            cost2, Pi2, C2 = sinkhorn_multi(base_torch,proto_temp3)

            # Aggregate samples per base prototype (Pi2 @ samples), then weight by
            # Pi's column for class i (assumes Pi2 is (n_base, sample_num) — TODO confirm).
            new_temp = torch.matmul(Pi[:, i-oways_session], torch.matmul(Pi2, proto_temp3))

            c_proto[i] = new_temp.cpu().numpy()


        c_proto = torch.from_numpy(c_proto).float().cuda(args.gpu)

        # model.key_mem.data = Tanh10x()(c_proto)
        # Replace the whole key memory with the calibrated copy.
        model.key_mem.data = c_proto

        model.nudge_prototypes(nways_session, writer, session, args.gpu)



        # Retrain: align support features of all seen classes with key_mem via the cosine loss.
        for epoch in range(args.retrain_iter):

            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss_cls = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            loss = 1.0 * loss_cls
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)


    # Stage 5: Fill up prototypes again
    # NOTE(review): this appears to overwrite the calibrated key_mem from stage 4;
    # confirm intended.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)


    # plot_confusion_support(model.key_mem.data.cpu(),
    #                     savepath="{:}/relu2{:}".format(args.log_dir, str(session)))

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)











def distribution_calibration(query, base_means, base_cov, k,alpha=0.21):
    """Calibrate a Gaussian for ``query`` from its ``k`` nearest base classes.

    The calibrated mean is the average of ``query`` and the ``k`` base means
    closest to it (Euclidean distance). The calibrated covariance is the
    average of those classes' covariances with ``alpha`` added to every entry.

    Args:
        query: 1-D feature vector.
        base_means: sequence/array of base-class mean vectors, each shaped like ``query``.
        base_cov: sequence/array of the matching base-class covariance matrices.
        k: number of nearest base classes to use. Values ``>= len(base_means)``
            select every base class (previously ``np.argpartition`` raised here).
        alpha: constant added to each entry of the averaged covariance.

    Returns:
        ``(calibrated_mean, calibrated_cov)`` as NumPy arrays.
    """
    means = np.asarray(base_means)
    n_base = len(means)
    # One Euclidean distance per base class, computed in a single vectorized pass.
    dist = np.linalg.norm(means - query, axis=1)
    if k < n_base:
        index = np.argpartition(dist, k)[:k]
    else:
        # argpartition requires kth < len(dist); asking for all classes just means all indices.
        index = np.arange(n_base)

    stacked = np.concatenate([means[index], query[np.newaxis, :]])
    calibrated_mean = np.mean(stacked, axis=0)
    calibrated_cov = np.mean(np.asarray(base_cov)[index], axis=0) + alpha

    return calibrated_mean, calibrated_cov




def distribution_calibration_dan(prototype, probabi, base_means, base_cov, n_lsamples, alpha=0.21, lambd=0.3, k=10):
    """Calibrate a Gaussian for ``prototype`` from per-base-class weights.

    With ``w = n_lsamples * probabi`` (one weight per base class)::

        mean = (1 - lambd) * sum_j w[j] * base_means[j] + lambd * prototype
        cov  = sum_j w[j] * base_cov[j] + alpha

    Args:
        prototype: 1-D vector of the class being calibrated.
        probabi: one weight per entry of ``base_means``. Either a torch tensor
            (any device; gradients are ignored) or any array-like.
        base_means: base-class mean vectors, stackable to ``(n_base, dim)``.
        base_cov: base-class covariances, stackable to ``(n_base, dim, dim)``.
        n_lsamples: scale applied to ``probabi``.
        alpha: constant added to every covariance entry.
        lambd: weight of ``prototype`` in the calibrated mean.
        k: unused; kept for backward compatibility.

    Returns:
        ``(calibrated_mean, calibrated_cov)`` as NumPy arrays.
    """
    if torch.is_tensor(probabi):
        # detach(): .numpy() refuses tensors that require grad; cpu(): needed for CUDA tensors.
        probabi = probabi.detach().cpu().numpy()
    weights = n_lsamples * np.asarray(probabi)

    means = np.asarray(base_means)
    covs = np.asarray(base_cov)
    # Broadcast one weight per base class across the feature (and feature x feature) axes.
    calibrated_mean = (1 - lambd) * np.sum(weights[:, None] * means, axis=0) + lambd * prototype
    calibrated_cov = np.sum(weights[:, None, None] * covs, axis=0) + alpha
    return calibrated_mean, calibrated_cov






def validation(model,criterion,dataloader, args,nways_session=None):
    """Evaluate ``model`` on ``dataloader``; return mean loss and top-1 accuracy.

    Args:
        model: network called as ``model(data)``; switched to eval mode here.
        criterion: loss called as ``criterion(output, label)``.
        dataloader: iterable of ``(data, label)`` batches.
        args: namespace providing ``gpu`` (device index used for ``.cuda``).
        nways_session: accepted for interface compatibility; currently unused.

    Returns:
        ``(avg_loss, avg_top1_percent, fig)``. Both averages are weighted by
        batch size. ``fig`` is always ``None``.
    """
    losses = AverageMeter('Loss', ':.4e')
    acc = AverageMeter('Acc@1', ':6.2f')

    model.eval()
    with t.no_grad():
        for batch in dataloader:
            data, label = [item.cuda(args.gpu, non_blocking=True) for item in batch]

            output = model(data)
            loss = criterion(output, label)
            accuracy = top1accuracy(output.argmax(dim=1), label)

            # Weight per-batch values by batch size so the averages are per sample.
            losses.update(loss.item(), data.size(0))
            acc.update(accuracy.item(), data.size(0))

    # Similarity-confusion plotting is disabled: the per-batch accumulation
    # (sim_conf.update with model.similarities) is commented out, so there is
    # nothing to plot. The previous guard `nways_session is (not None)` evaluated
    # to `nways_session is True` and was therefore always False; state the
    # outcome explicitly instead of hiding it behind that expression.
    fig = None
    return losses.avg, acc.avg, fig




# def validation_each(model, criterion, dataloader, args, nways_session=None, specific_classes=None):
#     losses = AverageMeter('Loss', ':.4e')
#     overall_acc = AverageMeter('Overall Acc', ':6.2f')
#     specific_acc = AverageMeter('Specific Acc', ':6.2f')
#
#     sim_conf = avg_sim_confusion(args.num_classes, nways_session)
#     model.eval()
#     with torch.no_grad():
#         for i, batch in enumerate(dataloader):
#             data, label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
#
#             output = model(data)
#             loss = criterion(output, label)
#
#             if specific_classes is not None:
#                 # Filter output and label for specific classes
#                 mask = label.unsqueeze(1).eq(torch.tensor(specific_classes, device=label.device).unsqueeze(0))
#                 output_specific = output[mask].view(-1, len(specific_classes))
#                 label_specific = label[mask].view(-1)
#                 specific_accuracy = top1accuracy(output_specific.argmax(dim=1), label_specific)
#                 specific_acc.update(specific_accuracy.item(), data.size(0))
#
#             accuracy = top1accuracy(output.argmax(dim=1), label)
#             overall_acc.update(accuracy.item(), data.size(0))
#
#             losses.update(loss.item(), data.size(0))
#             # if nways_session is not None:
#             #     sim_conf.update(model.similarities.detach().cpu(),
#             #                 F.one_hot(label.detach().cpu(), num_classes=args.num_classes).float())
#     # Plot figure if needed
#     fig = sim_conf.plot() if nways_session is not None else None
#
#     return losses.avg, overall_acc.avg, specific_acc.avg, fig


def validation_onehot(model,criterion,dataloader, args, num_classes):
    """Evaluate ``model`` against one-hot encoded targets.

    Integer labels of each batch are turned into float one-hot rows of width
    ``num_classes`` before being passed to ``criterion``. Accuracy comes from
    ``process_result`` (argmax agreement) and is reported in percent.

    Returns:
        ``(avg_loss, avg_accuracy_percent)``, both weighted by batch size.
    """
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc@1', ':6.2f')

    model.eval()
    with t.no_grad():
        for batch in dataloader:
            inputs, targets = [tensor.cuda(args.gpu, non_blocking=True) for tensor in batch]
            targets = F.one_hot(targets, num_classes=num_classes).float()

            predictions = model(inputs)
            batch_loss = criterion(predictions, targets)
            batch_accuracy = process_result(predictions, targets)[-1]

            n_samples = inputs.size(0)
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(batch_accuracy.item() * 100, n_samples)

    return loss_meter.avg, acc_meter.avg

# --------------------------------------------------------------------------------------------------
# Interpretation
# --------------------------------------------------------------------------------------------------
def process_result(predictions, actual):
    """Decode scores and targets into class labels and the batch accuracy.

    Args:
        predictions: score tensor; the predicted class is the argmax over dim 1.
        actual: target tensor in the same layout (e.g. one-hot); the true class
            is the argmax over dim 1.

    Returns:
        ``(predicted_labels, predicted_certainties, actual_labels,
        actual_certainties, accuracy)``. Both certainty entries are the
        placeholder ``0`` (not implemented yet). ``accuracy`` is the fraction
        of matching labels as a 1-element float tensor.
    """
    # Certainty estimates are a TODO; callers currently receive 0.
    predicted_certainties = actual_certainties = 0
    predicted_labels, actual_labels = (t.argmax(scores, dim=1) for scores in (predictions, actual))
    hits = (predicted_labels == actual_labels).float()
    accuracy = hits.mean(dim=0, keepdim=True)
    return predicted_labels, predicted_certainties, actual_labels, actual_certainties, accuracy


def process_dictionary(dict):
    """Return a mapping's items sorted by key, plus a two-column string table.

    Args:
        dict: mapping to tabulate. The name shadows the builtin; it is kept
            for backward compatibility with keyword callers.

    Returns:
        ``(dict_list, dict_table)``:

        - ``dict_list`` is the sorted list of ``(key, value)`` pairs.
        - ``dict_table`` is a NumPy string array of shape ``(len(dict), 2)``
          holding each key next to ``repr(value)``.

        An empty mapping yields ``([], <array of shape (0, 2)>)`` instead of
        raising.
    """
    dict_list = sorted(dict.items())
    if not dict_list:
        # zip(*[]) yields nothing, so it cannot be unpacked into keys/values.
        return dict_list, np.empty((0, 2), dtype=str)

    keys, values = zip(*dict_list)
    values = [repr(value) for value in values]
    dict_table = np.vstack((np.array(keys), np.array(values))).T

    return dict_list, dict_table

# --------------------------------------------------------------------------------------------------
# Summaries
# --------------------------------------------------------------------------------------------------
class AverageMeter(object):
    """Running tracker of a metric's latest value and count-weighted mean.

    The public attributes ``val``, ``avg``, ``sum`` and ``count`` are read
    directly by callers. ``name`` and ``fmt`` only affect ``str()``.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        # Format spec, including the leading ':', applied to both val and avg.
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',savedir=''):
    """Serialize ``state`` with ``torch.save`` to ``savedir/filename``.

    If ``is_best`` is true, the written file is also copied to
    ``savedir/model_best.pth.tar``.

    Paths are built with ``os.path.join``. The default ``savedir=''`` now
    resolves against the current working directory instead of the filesystem
    root: the old ``'' + '/' + filename`` produced ``'/checkpoint.pth.tar'``.
    """
    path = os.path.join(savedir, filename)
    t.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(savedir, 'model_best.pth.tar'))



def load_checkpoint(model,optimizer,scheduler,args):
    """Restore training state, preferring a run checkpoint over pretrained weights.

    Lookup order:
      1. ``<args.log_dir>/checkpoint.pth.tar``: restores model, optimizer and
         scheduler state, ``train_iter`` and ``best_acc1`` (resume a run).
      2. ``<args.resume>/model_best.pth.tar``: restores only the model weights.
         Optimizer/scheduler are untouched, iteration restarts at 0 and
         ``best_acc1`` is 0.
      3. Neither file exists: nothing is loaded; iteration 0, ``best_acc1 = 0``.

    When ``args.gpu`` is not None, tensors are loaded onto ``cuda:<args.gpu>``.

    Returns:
        ``(model, optimizer, scheduler, start_train_iter, best_acc1)``.
    """
    def _load(path):
        # map_location places tensors on the configured GPU regardless of the saving device.
        if args.gpu is None:
            return t.load(path)
        return t.load(path, map_location='cuda:{}'.format(args.gpu))

    # First priority: resume a run from the checkpoint in log_dir.
    resume = os.path.join(args.log_dir, 'checkpoint.pth.tar')
    if os.path.isfile(resume):
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = int(checkpoint['train_iter'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (train_iter {})"
              .format(args.log_dir, checkpoint['train_iter']))
        print('previous acc', best_acc1)
        return model, optimizer, scheduler, start_train_iter, best_acc1

    # Second priority: initialize weights from a pretrained model.
    # No scheduler and no optimizer loading here.
    resume = os.path.join(args.resume, 'model_best.pth.tar')
    if os.path.isfile(resume):
        print("=> loading pretrain checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        model.load_state_dict(checkpoint['state_dict'])
        print('previous best acc', checkpoint['best_acc1'])
        # Report the file that was actually loaded (this used to print args.log_dir).
        print("=> loaded pretrained checkpoint '{}' (train_iter {})"
              .format(resume, checkpoint['train_iter']))
        return model, optimizer, scheduler, 0, 0

    print("=> no checkpoint found at '{}'".format(args.log_dir))
    print("=> no pretrain checkpoint found at '{}'".format(args.resume))
    return model, optimizer, scheduler, 0, 0




def load_checkpoint2(model,optimizer,scheduler,args):
    """Like ``load_checkpoint`` but also returns stored prototype statistics.

    Lookup order:
      1. ``<args.log_dir>/checkpoint.pth.tar``: restores model, optimizer and
         scheduler state, ``train_iter`` and ``best_acc1``. The prototype
         fields are returned as ``None``.
      2. ``<args.resume>/model_best.pth.tar``: restores the model weights and
         returns the checkpoint's ``best_acc1``, ``prototype``, ``cov`` and
         ``classlabel``. Iteration restarts at 0; optimizer/scheduler untouched.
      3. Neither file exists: nothing is loaded; ``(0, 0, None, None, None)``.

    When ``args.gpu`` is not None, tensors are loaded onto ``cuda:<args.gpu>``.

    Returns:
        ``(model, optimizer, scheduler, start_train_iter, best_acc1,
        prototype, cov, classlabel)``.
    """
    def _load(path):
        # map_location places tensors on the configured GPU regardless of the saving device.
        if args.gpu is None:
            return t.load(path)
        return t.load(path, map_location='cuda:{}'.format(args.gpu))

    # First priority: resume a run from the checkpoint in log_dir.
    resume = os.path.join(args.log_dir, 'checkpoint.pth.tar')
    if os.path.isfile(resume):
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = int(checkpoint['train_iter'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (train_iter {})"
              .format(args.log_dir, checkpoint['train_iter']))
        print('previous acc', best_acc1)
        return model, optimizer, scheduler, start_train_iter, best_acc1, None, None, None

    # Second priority: initialize from a pretrained model (weights + prototype statistics).
    # No scheduler and no optimizer loading here.
    resume = os.path.join(args.resume, 'model_best.pth.tar')
    if os.path.isfile(resume):
        print("=> loading pretrain checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc1 = checkpoint['best_acc1']

        print('previous best acc', best_acc1)
        # Report the file that was actually loaded (this used to print args.log_dir).
        print("=> loaded pretrained checkpoint '{}' (train_iter {})"
              .format(resume, checkpoint['train_iter']))

        prototype = checkpoint['prototype']
        cov = checkpoint['cov']
        classlabel = checkpoint['classlabel']
        return model, optimizer, scheduler, 0, best_acc1, prototype, cov, classlabel

    print("=> no checkpoint found at '{}'".format(args.log_dir))
    print("=> no pretrain checkpoint found at '{}'".format(args.resume))
    return model, optimizer, scheduler, 0, 0, None, None, None

# --------------------------------------------------------------------------------------------------
# Some Pytorch helper functions (might be removed from this file at some point)
# --------------------------------------------------------------------------------------------------





def convert_toonehot(label):
    '''
    Convert a tensor of class indices to float one-hot rows.

    Columns for classes that never occur in ``label`` are dropped, so the
    result has shape ``(B, num_ways)``, where ``num_ways`` is the number of
    distinct labels present. Columns stay in ascending class-index order.

    The result stays on ``label``'s device. The previous
    ``.type(t.FloatTensor)`` silently moved CUDA inputs to the CPU.
    Presumably ``label`` is 1-D of integer dtype (required by ``F.one_hot``).
    '''
    label_onehot = F.one_hot(label)
    # Keep only the columns of classes that appear in the batch.
    present = label_onehot.sum(dim=0) != 0
    return label_onehot[:, present].float()

def top1accuracy(pred, target):
    """Return top-1 accuracy of ``pred`` against ``target``, in percent.

    Both arguments hold class indices with the batch along dim 0. The result
    is a tensor equal to ``100 * matches / batch_size``.
    """
    n_samples = target.size(0)
    matches = (pred == target).float()
    return matches.sum(dim=0).mul_(100.0 / n_samples)



def accuracy_for_each_task(pred, target):
    """Return the overall top-1 accuracy (in percent) of ``pred`` vs ``target``.

    Despite the name, this yields a single value for the whole batch and
    computes the same quantity as ``top1accuracy``.
    """
    total = target.size(0)
    n_correct = t.eq(pred, target).to(t.float32).sum(0)
    return n_correct.mul(100.0 / total)



class myRetrainDataset(Dataset):
    """Map-style dataset that pairs ``x[idx]`` with ``y[idx]``.

    Stores references to the given sequences (no copy). Its length is ``len(y)``.
    """
    def __init__(self, x,y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target


class MultiClassLogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        """Create a bias-free ``(output_dim, input_dim)`` weight matrix.

        The weight is initialized with ``kaiming_uniform_``
        (``mode='fan_out'``, ``a=sqrt(5)``).
        """
        # `math` is not imported explicitly at module level; do not rely on star imports.
        import math

        super(MultiClassLogisticRegression, self).__init__()
        # A raw Parameter (not nn.Linear): forward applies it with F.linear and no bias.
        # torch.empty replaces the legacy t.FloatTensor(...) constructor (same float32 CPU tensor).
        self.linear = nn.Parameter(t.empty(output_dim, input_dim, dtype=t.float32))

        nn.init.kaiming_uniform_(self.linear, mode='fan_out', a=math.sqrt(5))


    def forward(self, x):
        """Return raw (unnormalized) scores ``x @ self.linear.T`` with no bias."""
        # Cosine-similarity variant, kept for reference and disabled:
        #   F.linear(F.normalize(x, dim=1), F.normalize(self.linear, dim=1))
        return F.linear(x, self.linear)